import pySALEPlot as psp

import matplotlib.pyplot as plt
import numpy             as np
import matplotlib.colors as mcolors

from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib.offsetbox    import AnchoredText
from matplotlib.ticker       import AutoMinorLocator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

################ Input/Output files and directories and general plotting parameters ################
flat = True              # Make flat or curved plot?
stacked_plot = False     # Stack the subplots if they're from the same model/field (True). Set to False
                         # if you want to plot data from different models
mirror_fields = False    # Two columns from the same model (True: field variable on the left, tracer
                         # depths on the right), or a single column of panels (False)?


dirname  = 'ring_formation/paper_figs'  # Output directory
figname  = 'fig5_tracer_tps_flat_v1'    # Figure name
fig_type = 'pdf'                        # What type of figure (e.g., png, pdf, eps)
dpi      = 1000                         # Resolution of figure

# iSALE data file(s)
model_files = ['ring_formation/35Gcrst_120imp_30kpkm_1400_liqcor_a/jdata.dat',
               'ring_formation/35Gcrst_120imp_30kpkm_1400_liqcor_b/jdata.dat',
               'ring_formation/35Gcrst_120imp_30kpkm_1400_liqcor_c/jdata.dat']


# Model specific data (time steps, field(s), etc.)
m           = 0                                     # Which model file to start at
dt          = [18, 32, 40, 52, 145]                 # Simulation times to plot (minutes)
letters     = ['a', 'b', 'c', 'd', 'e', 'f', 'g']   # Letters for subpanels
field_type  = 'TPS'                                 # Must be one of the iSALE saved field variables
field_label = 'Total plastic strain'                # Field label for colorbar

min_field   = 0                            # Minimum value for field variable (TPS in this case)
max_field   = 10                           # Maximum value for field variable (TPS in this case)
field_ticks = [0,2,4,6,8,10]               # Where to place field ticks on colorbar
depth_ticks = [0,5,10,15,20,25,30,35]      # Where to place depth ticks on colorbar (units of km)
vel_ticks   = [-350, -250, -150, -50, 0, 50, 150, 250, 350]   # Velocity colorbar ticks (m/s)

# These are only used when plotting tracer depths
crust  = 35.     # preimpact crustal thickness (km)
mantle = 420.    # depth to core-mantle boundary (km)
radius = 2240.   # radius of planet (km)

# Plotting parameters
fig_width  = 5         # Figure width (in); 7.25 for the full-page layout
fig_height = 6         # Figure height (in); 9.5 for the full-page layout
nrows      = 2         # Number of rows in figure (use len(dt) for one row per time step)
ncols      = 1         # Number of columns in figure

fs         = 8        # font size for all text in plot
alp        = 1        # transparency of dots in scatter plot
sz         = 1        # size of dots in scatter plot
ve         = 2        # vertical exaggeration factor for flat plots

# Colormaps and contour line colors for field plot(s)
css4 = mcolors.CSS4_COLORS['brown']                          # Alternative 'set_over' color (currently
                                                             # unused; the mantle map uses 'black')
cmap_vel    = plt.get_cmap('coolwarm')                       # Color map for tracer velocities
cmap_field  = plt.get_cmap('magma_r')                        # Color map for field variable
cmap_crust  = plt.get_cmap('viridis_r', len(depth_ticks)-1)  # Color map for crust
# The number of discrete levels ('lut') must be an integer: colormap resampling builds a
# lookup table with np.linspace(0, 1, lut), which rejects non-integer sizes.
cmap_mantle = plt.get_cmap('bone_r', 19)                     # Color map for mantle
cmap_mantle.set_over('black')                                # If over, color this (mantle)
cont_color  = 'gray'                                         # Color to contour material boundaries
                                                             # in the field plot

######################################### Define functions #########################################

# Associate model files with desired time steps as they're spread out after using
# iSALE's restart function
def modelLocator(model_files, dt, m = 0):
  """
  Maps each requested simulation time to the model file that holds it and the closest saved step.

  INPUTS:
    model_files -- list of 'jdata.dat' paths, in chronological order (restart files after
                   the original run)
    dt          -- requested simulation times in minutes; assumed to be in ascending order,
                   because the search never moves back to an earlier file
    m           -- index into 'model_files' to start searching from (default 0)

  OUTPUTS:
    model_locator -- float array of shape (len(dt), 3); row i is
                     [dt[i], model-file index, step index within that file]

  RAISES:
    ValueError -- if a requested time lies beyond the end of the last model file
  """
  model_locator = np.empty(shape = (len(dt), 3))

  # Find the start and end time (minutes) of the initial model file
  model = psp.opendatfile(model_files[m], verbose = False)
  start, end, num = findTimeSteps(model, convert_time = True)

  for i, time in enumerate(dt):

    # Keep moving to later files until one covers the requested time. A single 'if' would
    # only advance one file, giving the wrong file when a time skips past a whole file.
    while int(time) > int(end):
      m += 1
      if m >= len(model_files):
        raise ValueError("Requested time {} min is after the end of the last model file "
                         "({:.1f} min)".format(time, end))

      model = psp.opendatfile(model_files[m], verbose = False)
      start, end, num = findTimeSteps(model, convert_time = True)

    model_locator[i] = [time, m, findModelStep(time, start, end, num)]

  return model_locator

# Create a figure
def createFigure(nrows, ncols, fig_width, fig_height, direction = 'column', flat = False):
  """
  Creates a figure panel using the AxesGrid approach to prevent subplots from getting distorted.

  The layout depends on the module-level flag 'stacked_plot':
    True  -- panels touch each other (no padding) and no colorbar axes are created;
    False -- rows are separated by a small gap and AxesGrid adds a colorbar axis at the
             right-hand edge of each row (grid.cbar_axes).

  INPUTS:
    nrows, ncols      -- rows and columns of the grid
    fig_width(height) -- how wide (high) we want the figure in inches
    direction         -- order in which AxesGrid numbers the panels ('column' or 'row')
    flat              -- True for flattened plots: aspect set to the vertical exaggeration
                         're', depth axis increasing downward from -40 to 160 km.
                         False for curved plots: equal aspect, height axis from -350 to 50 km.

  OUTPUTS:
    fig, grid  -- figure and grid instances
  """

  fig = plt.figure(figsize = (fig_width, fig_height))   # Set up a figure

  if stacked_plot:
    # Panels share edges; no colorbar axes
    grid = AxesGrid(fig, 111, nrows_ncols = (nrows, ncols), axes_pad = (0,0), direction = direction)

  else:
    # Padded rows with one colorbar axis on the right edge of each row
    grid = AxesGrid(fig, 111, nrows_ncols = (nrows, ncols), axes_pad = (0, 0.2), direction = direction,
                    cbar_mode = 'edge', cbar_location = 'right', cbar_pad = 0.1, cbar_size = "3%")

  for k in range(len(grid)):
    ax = grid[k]

    if flat:
      ax.set_aspect(ve)
      ax.set_xticks([0,200,400,600,800,1000])
      ax.set_yticks([160, 120, 80, 40, 0, -40])
      ax.tick_params(axis = 'both', which = 'both', direction = 'in',
                     top = True, right = True, labelsize = fs)
      # Leave the topmost (-40 km, above the surface) tick unlabelled
      ax.yaxis.set_ticklabels(['160','120','80','40','0', ''])
      ax.xaxis.set_minor_locator(AutoMinorLocator(2))
      ax.yaxis.set_minor_locator(AutoMinorLocator(2))

      # Limits: ylim is given (bottom, top) = (160, -40) so depth increases downward
      ax.set_xlim([0, 1100])
      ax.set_ylim([160, -40])

    else:
      ax.set_aspect('equal')
      ax.set_xticks([0, 200, 400, 600, 800, 1000])
      ax.set_yticks([-300, -200, -100, 0, 100])
      ax.tick_params(axis = 'both', which = 'both', direction = 'in',
                     top = True, right = True, labelsize = fs)
      ax.xaxis.set_minor_locator(AutoMinorLocator(2))

      # Set limits for axes
      ax.set_xlim([0, 1000])
      ax.set_ylim([-350, 50])

    # Mirrored layout: reverse the x-axis of the first half of the panels (the first
    # column when direction == 'column') so the two halves face each other
    if ncols > 1 and k < len(grid)/2:
      ax.invert_xaxis()

  return fig, grid

# Finds start and end times and total number of steps in a given model file
def findTimeSteps(model, convert_time = True):
  """ Finds the first and last simulation times stored in a loaded iSALE model file.

  Inputs: model        -- an iSALE model file that has already been opened (e.g. with
                          psp.opendatfile); only 'nsteps' and 'readStep' are used
          convert_time -- if truthy, convert the times from seconds to minutes

  Outputs: start_time  --  time of the first saved step (minutes if convert_time, else seconds)
           end_time    --  time of the last saved step (same units as start_time)
           num_steps   --  index of the last saved step, i.e. model.nsteps - 1
                           (equivalently, the number of intervals between saved steps)
  """

  # Index of the last saved step (indexing starts at 0)
  num_steps = int(model.nsteps - 1)

  # Only the time stamps of the first and last steps are needed; 'TPS' is simply the
  # field read alongside them
  start_time = model.readStep(['TPS'], 0).time
  end_time   = model.readStep(['TPS'], num_steps).time

  # Convert from seconds to minutes if desired
  if convert_time:
    start_time = start_time/60.
    end_time   = end_time/60.

  return start_time, end_time, num_steps

# Find time step of model that corresponds with a desired simulation time 
def findModelStep(time_step, start_time, end_time, num_steps):
  """
  Finds the index of the saved step closest to a desired simulation time.

  Saved steps are assumed to be evenly spaced in time between start_time and end_time.

  Inputs: time_step  -- desired simulation time (same units as start_time/end_time, minutes)
          start_time -- time of the first saved step in this model file
          end_time   -- time of the last saved step in this model file
          num_steps  -- index of the last saved step (as returned by findTimeSteps)

  Output: model_step -- index of the saved step associated with 'time_step'

  Raises: ValueError -- if the file has several saved steps but end_time <= start_time
  """
  # A file with a single saved step can only ever return that step
  if num_steps <= 0:
    return 0

  span = end_time - start_time
  if span <= 0:
    raise ValueError("end_time ({}) must be later than start_time ({})".format(end_time, start_time))

  # Saved steps per minute over this file's own time span. Dividing by end_time alone is
  # only correct for a file that starts at t = 0; a restart file starts later. The rate is
  # not rounded, because rounding turns rates below 0.5 steps/min into 0.
  steps_per_min = num_steps / float(span)

  # Special case when iSALE's restart function is used and the model is spread across multiple
  # 'jdata.dat' files: step indices in a restart file count from its start time.
  offset = int(start_time) if start_time > 0. else 0

  return int(round(steps_per_min * (time_step - offset)))

# Find tracer depths at initial time step and simplify syntax for tracer (x,y)
def tracerDepths(model0, step0, model, step, u, Rp):
  """
  Returns the current (unflattened) tracer positions of unit 'u' and each tracer's initial depth.

  INPUTS:
    model0, step0 -- model file and step holding the initial tracer positions
    model, step   -- model file and step holding the current tracer positions
    u             -- tracer unit; its index range is model.tru[u].start:model.tru[u].end
    Rp            -- planet radius, in the same length units as the tracer coordinates

  OUTPUTS:
    x, y  -- current tracer coordinates, taken unchanged from step.xmark/step.ymark
    depth -- initial depth, Rp minus the initial distance from the planet centre at (0, -Rp)
  """
  initial = model0.tru[u]
  current = model.tru[u]

  x_init = step0.xmark[initial.start:initial.end]
  y_init = step0.ymark[initial.start:initial.end]
  depth  = Rp - np.sqrt((y_init + Rp)**2 + x_init ** 2)

  return (step.xmark[current.start:current.end],
          step.ymark[current.start:current.end],
          depth)

# Convert tracers to flat coordinates
def tracerFlatten(model0, model, step0, step, u, Rp):
  """
  Unrolls tracer positions from the curved (planet-centred) geometry onto a flat plane.

  The planet centre sits at (0, -Rp) in model coordinates. Each tracer's distance r from that
  centre and its angle theta from the symmetry axis (np.arctan(x / (y + Rp))) give the planar
  coordinates x_flat = r * theta (arc length at radius r) and y_flat = Rp - r.

  INPUTS:
    model0 -- Model file with the initial time step (t = 0 s)
    model  -- Model file with the time step being plotted
    step0  -- Step from t = 0 s
    step   -- Step from t = t1
    u      -- Tracer unit; its index range is model.tru[u].start:model.tru[u].end
    Rp     -- Radius of planet, in the same length units as the tracer coordinates

  OUTPUTS:
    x_flat, y_flat -- flattened coordinates of each tracer at 'step'
    depth          -- initial depth of each tracer, Rp minus its radius at 'step0'
  """
  now  = slice(model.tru[u].start, model.tru[u].end)
  then = slice(model0.tru[u].start, model0.tru[u].end)

  # Coordinates relative to the planet centre
  x_now  = step.xmark[now]
  y_now  = step.ymark[now] + Rp
  x_then = step0.xmark[then]
  y_then = step0.ymark[then] + Rp

  dist_now  = np.sqrt((x_now**2) + (y_now**2))
  dist_then = np.sqrt((x_then**2) + (y_then**2))
  angle     = np.arctan(x_now / y_now)

  return dist_now * angle, Rp - dist_now, Rp - dist_then

# Convert centroids to flat coordinates
def centroidFlatten(model, Rp, mask = 999999):
  """
  Converts cell-centroid coordinates from the curved model to planar coordinates.

  Uses the same transform as the tracers: r is the distance from the planet centre at (0, -Rp),
  theta = np.arctan(xc / (yc + Rp)), x_flat = r * theta and y_flat = Rp - r. Centroids with
  theta <= 0 (including those below the planet centre) are replaced by 'mask'.

  INPUTS:
    model -- Model file (reads model.xc and model.yc)
    Rp    -- Radius of planet, in the same length units as the centroids
    mask  -- Dummy value for masking array. Otherwise, values from bottom half
             of planet will overprint the main data.

  OUTPUTS:
    xc_flat -- X-centroid coordinates
    yc_flat -- Y-centroid coordinates
  """
  shifted = model.yc + Rp
  angle   = np.arctan(model.xc / shifted)
  dist    = np.sqrt(model.xc**2 + shifted**2)

  keep = angle > 0.
  return np.where(keep, dist * angle, mask), np.where(keep, Rp - dist, mask)

# Convert cell vertices to flat coordinates
def verticiesFlatten(model, Rp, mask = 999999):
  """
  Converts cell vertices from the curved model to planar coordinates.

  Vertices with y >= Rp get their shifted height replaced by 'mask' before the transform.
  The rest follow the centroid/tracer transform: r is the distance from (0, -Rp),
  theta = np.arctan(x / (y + Rp)), x_flat = r * theta and y_flat = Rp - r. Results with
  theta < -0.5 rad are replaced by 'mask'.
  NOTE(review): unlike centroidFlatten (theta > 0), the -0.5 rad cut-off keeps a band of
  negative-angle vertices; presumably so the mesh stays closed for pcolormesh. Confirm.

  INPUTS:
    model -- Model file (reads model.x and model.y)
    Rp    -- Radius of planet, in the same length units as the vertices
    mask  -- Dummy value for masking array. Otherwise, values from bottom half
             of planet will overprint the main data.

  OUTPUTS:
    x_flat -- Vertices in x-direction
    y_flat -- Vertices in y-direction
  """
  shifted = np.where(model.y < Rp, model.y + Rp, mask)
  angle   = np.arctan(model.x / shifted)
  dist    = np.sqrt(model.x**2 + shifted**2)

  keep = angle >= -0.5
  return np.where(keep, dist * angle, mask), np.where(keep, Rp - dist, mask)

# Calculate horizontal (toward/away from basin center) velocity profile
def tracerVelocity(x1, x2, time1, time2):
  """
  Calculates the horizontal (along-surface) velocity of tracers between two saved steps.

  This function assumes the coordinate transformation has already been done: x1 and x2
  should be the flattened x-coordinates returned by 'tracerFlatten()'.

  INPUTS:
    x1, x2       -- x-coordinates of the same tracers (same order) at time1 and time2
    time1, time2 -- simulation times of the two steps, with time2 > time1 (seconds)

  OUTPUTS:
    velocities -- x-displacement divided by the elapsed time, same shape as x1/x2; in m/s
                  when the coordinates are in metres. Negative means motion toward the
                  basin center (x = 0).

  RAISES:
    ValueError -- if time2 <= time1 (the result would be infinite or have the wrong sign)
  """
  delta_t = time2 - time1   # Time interval (s)
  if delta_t <= 0:
    raise ValueError("time2 must be later than time1 (got time1 = {}, time2 = {})".format(time1, time2))

  # Now that the coordinate transformation is done, the displacement calculation
  # is straightforward.
  return (x2 - x1) / delta_t

# Plot tracer velocities 
def plotVelocities(idx):
  """
  Plots tracer velocities for curved models that have been transformed to planar coordinates.

  Each tracer is colored by its along-surface velocity between the requested saved step and
  the step before it (negative = toward the basin center). Material boundaries, the simulation
  time and the subpanel letter are added to the same panel.

  Uses module-level state prepared by the main program: 'model', 'model0', 'step0',
  'model_locator', 'grid', 'flat' and 'radius' (km). Side effect: both 'model' and 'model0'
  are switched to metre scale with setScale('m').

  INPUTS:
    idx -- panel index in 'grid'. It is also the row of 'model_locator' holding the step to
           plot, and the index into 'letters'.
  OUTPUTS:
    p1 -- scatter artist of the last tracer unit drawn (for assigning colorbar)
  """
  # Work in metres so that velocities come out in m/s
  model.setScale('m')
  model0.setScale('m')
  Rp = radius*1e3   # Planet radius (m)

  # The requested step and the one before it bracket the velocity estimate.
  # NOTE(review): a requested step index of 0 would read step -1 here; confirm it never occurs.
  step_index = int(model_locator[idx, 2])
  step1 = model.readStep('TPS', step_index - 1)
  step2 = model.readStep('TPS', step_index)
  time1 = step1.time
  time2 = step2.time

  # Material-boundary coordinates (km) do not depend on the tracer unit, so compute them once
  if flat:
    xc, yc = centroidFlatten(model, Rp)
    xc = xc/1e3
    yc = yc/1e3
  else:
    xc = model.xc/1e3
    yc = model.yc/1e3

  # Go through all the tracer units and calculate their velocities
  for u in range(1, 4):

    x1, _, _  = tracerFlatten(model0, model, step0, step1, u, Rp)
    x2, y2, _ = tracerFlatten(model0, model, step0, step2, u, Rp)

    velocities = tracerVelocity(x1, x2, time1, time2)

    # Curved plots show the tracers at their unflattened positions
    if not flat:
      x2, y2, _ = tracerDepths(model0, step0, model, step2, u, Rp)

    p1 = grid[idx].scatter(x2/1e3, y2/1e3, c = velocities, cmap = cmap_vel,
                           vmin = min(vel_ticks), vmax = max(vel_ticks), s = sz,
                           linewidths = 0, rasterized = True)

  # Outline material boundaries at the plotted step
  for mm in [0,2]:
    grid[idx].contour(xc, yc, step2.cmc[mm], 1, colors = 'k', linewidths = 1)

  # Label each panel with time step & letter for figure call outs
  at1 = AnchoredText("{:3.0f} min".format(time2/60.), prop = dict(size = fs-1),
                     frameon = False, loc = 'upper right')
  grid[idx].add_artist(at1)

  at2 = AnchoredText("{}".format(letters[idx]), prop = dict(size = fs+1, fontweight = "bold"),
                     frameon = False, loc = 'lower left', bbox_to_anchor=(-0.15, 0.75),
                     bbox_transform=grid[idx].transAxes)
  grid[idx].add_artist(at2)

  return p1
    
# Plot tracer depths for curved models
def plotTracers(idx, flat = False):
  """
  Plots tracers colored by their initial depth, then outlines material boundaries.

  Uses module-level state prepared by the main program: 'model', 'model0', 'step', 'step0',
  'grid' and 'radius' (same length units as the loaded model; km in the main program).
  Mantle tracers (units 2 and 3) use 'cmap_mantle' on [crust + 1, mantle]. Crust tracers
  (unit 1) use 'cmap_crust' on [0, crust] and are drawn afterwards, so they sit on top.

  INPUTS:
    idx  -- panel index in 'grid'
    flat -- True to unroll the curved geometry with tracerFlatten/centroidFlatten
  OUTPUTS:
    p2 -- scatter artist of the crustal tracers (for assigning a colorbar)
  """
  def _coords(u):
    # Tracer (x, y) at 'step' and initial depth for tracer unit u
    if flat:
      return tracerFlatten(model0, model, step0, step, u, radius)
    return tracerDepths(model0, step0, model, step, u, radius)

  # Centroid coordinates for the boundary contours; identical for every tracer unit
  if flat:
    xc, yc = centroidFlatten(model, radius, 999999)
  else:
    xc = model.xc
    yc = model.yc

  # Plot mantle material
  for u in range(2, 4):
    x, y, depth = _coords(u)
    grid[idx].scatter(x, y, c = depth, cmap = cmap_mantle, vmin = crust + 1, vmax = mantle,
                      s = sz, linewidths = 0, rasterized = True)

  # Plot crustal material
  x, y, depth = _coords(1)
  p2 = grid[idx].scatter(x, y, c = depth, cmap = cmap_crust, vmin = 0, vmax = crust, s = sz,
                         linewidths = 0, rasterized = True)

  # Contour material boundaries
  for mm in [0,2]:
    grid[idx].contour(xc, yc, step.cmc[mm], 1, colors = 'k', linewidths = 0.75)

  # Label each panel with its time step for figure call outs
  at1 = AnchoredText("{:3.0f} min".format(step.time/60.), prop = dict(size = fs-1),
                     frameon = False, loc = 'upper right')
  grid[idx].add_artist(at1)

  return p2

# Plot field data for curved models
def plotField(idx, flat = False):
  """
  Draws the first loaded field of the module-level 'step' (step.data[0]) on panel grid[idx].

  The field is drawn as a pcolormesh with 'cmap_field' limited to [min_field, max_field]. The
  function outlines the boundaries of materials 0 and 2 in 'cont_color', labels the vertical
  axis, and stamps the simulation time in the upper-right corner.

  INPUTS:
    idx  -- panel index in 'grid'
    flat -- True to unroll the curved geometry with verticiesFlatten/centroidFlatten
            (axis labelled 'Depth (km)'); False plots raw model coordinates ('Height (km)')
  OUTPUTS:
    mesh -- the pcolormesh artist (for assigning a colorbar)
  """
  ax = grid[idx]

  if not flat:
    # Raw model coordinates
    x_cord, y_cord = model.x, model.y
    xc, yc = model.xc, model.yc
    ax.set_ylabel('Height (km)', fontsize = fs)
  else:
    # Cell vertices for the mesh, centroids for the boundary contours
    x_cord, y_cord = verticiesFlatten(model, radius)
    xc, yc = centroidFlatten(model, radius)
    ax.set_ylabel('Depth (km)', fontsize = fs)

  mesh = ax.pcolormesh(x_cord, y_cord, step.data[0],
                       cmap = cmap_field, vmin = min_field, vmax = max_field,
                       shading = 'auto', rasterized = True)

  # Material boundaries
  for material in (0, 2):
    ax.contour(xc, yc, step.cmc[material], 1, colors = cont_color, linewidths = 0.75)

  # Simulation time label
  time_label = AnchoredText("{:3.0f} min".format(step.time/60.), prop = dict(size = fs-1),
                            frameon = False, loc = 'upper right')
  ax.add_artist(time_label)

  return mesh

# Create colorbar(s)
def makeCbars(axs, field, ticks, label):
  """
  Attaches a colorbar for 'field' in an inset axes of panel 'axs'.

  When 'stacked_plot' is True the bar is horizontal across the top of the panel, with ticks
  and label on top. Otherwise it is vertical, just to the right of the panel, with ticks and
  label on the right.

  INPUTS:
    axs   -- panel (Axes) to attach the colorbar to
    field -- mappable (e.g. scatter or pcolormesh artist) the colorbar describes
    ticks -- tick positions on the colorbar
    label -- colorbar label
  OUTPUTS:
    cb -- the created matplotlib Colorbar
  """
  # Use the panel's own figure rather than relying on a module-level 'fig'
  figure = axs.figure

  if stacked_plot:
    axins = inset_axes(axs, width = "85%", height = "15%", loc = 'upper center',
                       bbox_to_anchor = (0, 0.7, 1, 0.5), bbox_transform = axs.transAxes,
                       borderpad = 0)

    cb = figure.colorbar(field, cax = axins, ticks = ticks, orientation = 'horizontal')
    cb.ax.set_xlabel(label, fontsize = fs-1)
    cb.ax.tick_params(direction = 'in', which = 'both', labelsize = fs-1, length = 5, bottom = True)
    cb.ax.xaxis.set_ticks_position('top')
    cb.ax.xaxis.set_label_position('top')

  else:
    axins = inset_axes(axs, width = "5%", height = "75%", loc = 'lower left',
                       bbox_to_anchor = (1.05, 0, 1, 1), bbox_transform = axs.transAxes,
                       borderpad = 0)

    cb = figure.colorbar(field, cax = axins, ticks = ticks, orientation = 'vertical')
    # A vertical bar's long axis is its y-axis, so the label goes there
    cb.ax.set_ylabel(label, fontsize = fs-1)
    cb.ax.tick_params(direction = 'in', which = 'both', labelsize = fs-1, length = 6, right = True)
    cb.ax.yaxis.set_ticks_position('right')
    cb.ax.yaxis.set_label_position('right')

  return cb

######################################## Begin main program ########################################
if __name__ == '__main__':

  # Make the figure output directory
  psp.mkdir_p(dirname)

  # Load the initial (t = 0) step of the first model file, in km. The plotting functions read
  # 'model0' and 'step0' as module-level names to recover each tracer's initial position.
  model0 = psp.opendatfile(model_files[0], verbose = False)
  model0.setScale('km')
  step0  = model0.readStep(field_type, 0)

  # Map each requested time in 'dt' (minutes) to the model file that contains it and the saved
  # step closest to it. Each row is [time, model-file index, step index]. This also handles
  # models split over several 'jdata.dat' files by iSALE's restart function.
  model_locator = modelLocator(model_files, dt, m)

  # Create the figure once; every branch below draws into this 'fig'/'grid'
  fig, grid = createFigure(nrows, ncols, fig_width, fig_height, flat = flat)

  if stacked_plot:
    if mirror_fields:

      # Two columns: the field variable on the left, tracer initial depths on the right.
      # createFigure numbers the panels column by column (default direction 'column'), so
      # row i of the second column is grid[i + nrows].
      for j in range(ncols):

        # Loop over rows
        for i in range(nrows):

          # Load this row's model file and step, with distances in km
          model = psp.opendatfile(model_files[int(model_locator[i,1])], verbose = False)
          model.setScale('km')
          step = model.readStep([field_type], int(model_locator[i,2]))

          # First column: field variable
          if j == 0:

            p1 = plotField(i, flat = flat)

            # Add color bar
            if i == 0:
              makeCbars(grid[i], p1, field_ticks, label = field_label)

            # Add horizontal axis label to the bottom row
            if i == nrows - 1:
              grid[i].set_xlabel('Radial distance (km)', fontsize = fs)

          # Second column: tracer initial depths
          if j == 1:

            idx = i + nrows
            p1 = plotTracers(idx, flat = flat)

            # Add colorbar
            if i == 0:
              makeCbars(grid[idx], p1, depth_ticks, label = 'Initial depth (km)')

            # Add horizontal axis label to the bottom row
            if i == nrows - 1:
              grid[idx].set_xlabel('Radial distance (km)', fontsize = fs)

    # A single column of stacked panels, one per requested time (mirror_fields == False)
    else:

      # Never index past the panels that exist in 'grid'
      n_panels = min(len(model_locator), len(grid))

      for i in range(n_panels):

        model = psp.opendatfile(model_files[int(model_locator[i,1])], verbose = False)
        model.setScale('km')
        step = model.readStep([field_type], int(model_locator[i,2]))

        # Set vertical axis label and plot the data
        grid[i].set_ylabel('Depth (km)', fontsize = fs)
        p1 = plotVelocities(i)

        # Add vertical exaggeration label
        at3 = AnchoredText("V.E. = {}:1".format(ve), prop = dict(size = fs-1), frameon = True,
                           loc = 'lower right')
        grid[i].add_artist(at3)

        # Make colorbar
        if i == 0:
          makeCbars(grid[i], p1, ticks = vel_ticks, label = 'Velocity relative to basin center (m/s)')

        # Add horizontal axis label to the bottom panel
        if i == n_panels - 1:
          grid[i].set_xlabel('Radial distance (km)', fontsize = fs)

  # A single column showing different fields of the same model and time step (dt[0]):
  # tracer initial depths in the first panel, the field variable in the others
  else:

    # Every panel uses the same file and step, so load them once
    model = psp.opendatfile(model_files[int(model_locator[0,1])], verbose = False)
    model.setScale('km')
    step = model.readStep([field_type], int(model_locator[0,2]))

    for i in range(nrows):

      grid[i].set_ylabel('Depth (km)', fontsize = fs)

      # Add subpanel letters
      at2 = AnchoredText("{}".format(letters[i]), prop = dict(size = fs+1, fontweight = "bold"),
                         frameon = False, loc = 'lower left', bbox_to_anchor=(-0.12, 0.8),
                         bbox_transform=grid[i].transAxes)
      grid[i].add_artist(at2)

      # Add vertical exaggeration label
      at3 = AnchoredText("V.E. = {}:1".format(int(ve)), prop = dict(size = fs-1), frameon = True,
                         loc = 'lower right')
      grid[i].add_artist(at3)

      if i == 0:
        p1 = plotTracers(i, flat)

        # Add color bar (one colorbar axis per row, created by AxesGrid in createFigure)
        cbar = grid.cbar_axes[i].colorbar(p1)
        cbar.ax.tick_params(direction = 'in', which = 'both', labelsize = fs, length = 8,
                            bottom = True)
        cbar.ax.set_ylabel('Initial depth (km)', fontsize = fs)
        cbar.ax.set_yticks(depth_ticks)

      else:
        p1 = plotField(i, flat)

        # Add horizontal axis label
        grid[i].set_xlabel('Radial distance (km)', fontsize = fs)

        # Add color bar
        cbar = grid.cbar_axes[i].colorbar(p1)
        cbar.ax.set_ylabel(field_label, fontsize = fs)
        cbar.ax.tick_params(direction = 'in', which = 'both', labelsize = fs, length = 8,
                            bottom = True)
        cbar.ax.set_yticks(field_ticks)

  # Save the figure
  fig.savefig('{}/{}.{}'.format(dirname, figname, fig_type), dpi = dpi, bbox_inches = 'tight')

